import argparse
import json
from typing import Any
from collections import defaultdict
import copy
import tiktoken
from beartype import beartype
from agent import Agent
from browser_env import Trajectory
from browser_env.actions import (
    Action,
    ActionParsingError,
    create_id_based_action,
    create_none_action,
    create_playwright_action,
)
from browser_env.utils import Observation, StateInfo

import os
import random
import numpy as np
import torch
import transformers
import re
import html

from lxml import html as lxml_html
from lxml import etree
import math
from openai import OpenAI


# Address substituted for the original host in observations and combined with
# per-site ports in WorkflowAgent.replace_urls; empty string when unset.
IP_ADDR = os.environ.get("IP_ADDR", "")
# Token forwarded as `token=` to the transformers `from_pretrained` calls;
# empty string when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Passed to the tokenizer as `model_max_length`.
MAX_CTX_LEN = 32768 * 3
# Character cap applied when the accessibility tree is sliced into a GPT prompt.
MAX_GPT_LEN = 500000
# NOTE(review): not referenced in the visible part of this file — confirm usage.
MAX_RECURSION = 16000
# System message placed in front of every workflow-model chat prompt.
PREPEND = "Help achieve the objective by generating the next step."
# NOTE(review): not referenced in the visible part of this file — confirm usage.
DEFAULT_EOS_TOKEN = "<|im_end|><|endoftext|>"



class GPTAgent:
    """OpenAI-backed helper that answers auxiliary questions for the workflow agent.

    An instance bundles one prompt prefix per supported goal ("stop",
    "reflect", "intent", "map" and next-step generation) with the logic that
    turns the raw completion into a ``(flag, payload)`` pair.
    """

    def __init__(self, intent_prompt=""):
        """Create the OpenAI client and the per-goal prompt prefixes.

        Args:
            intent_prompt: Prefix used when ``reflect`` is called with
                ``goal="intent"``.
        """
        self.client = OpenAI()
        # NOTE(review): temperature / top_p are stored here but are not
        # forwarded to the completion request in reflect() — confirm intent.
        self.temperature = 0.7
        self.top_p = 0.95
        self.model = "gpt-4o"

        self.reflect_prompt = "You are an autonomous agent helping users to solve web-based tasks. These tasks will be accomplished through series of actions. \
The information you'll have includes:\n\
- The user's objective\n\
- The current web page's URL\n\
- The current web page's accessibility tree\n\
- Previous steps performed by the user, where each step includes a description of the action and the target web element\n\
- Several proposed next steps,'' labeled by \"No.\"\n\
Your goal is to select the best next step that can complete the task and output this candidate's number, follow the following rules:\n\
- Do not repeat previous steps\n\
- Reject candidates with incorrection intentions, e.g., searching for an item different from the one specified in the objective\n\
- Reject candidates with factual errors, e.g., the description and the chosen web target do not match\n\
- Only output a single number after to represent the selected candidate but not explanation\nNow analyze the following case.\n"

        self.map_prompt = "You are an autonomous agent helping users to solve web-based tasks. These tasks will be accomplished through series of actions. \
The information you'll have includes:\n\
- The user's objective\n\
- The current web page's URL\n\
- A snippit of the current web page's HTML\n\
- A snippit of the current web page's accessibility tree\n\
- Previous steps performed by the user\n\
Your goal is to translate a proposed next step, which consists of an action and a HTML element, into the following format:\n\
- `click [accessibility tree id]`: This action clicks on an interactive (non-static) element with a specific id. Note this id is the number inside \"[]\" in the accessibility tree, not the HTML attribute \"node\". Brackets are required in the response. For example, a valid response is ``click [1234]``\n\
- `type [accessibility tree id] [content]`: Use this to type the content into the field with a specific id in the accessibility tree. For example, a valid response is ``type [1234] [New York]``. The second bracket should include everything that needs to appear in the textbox, but not only the added content. Do not change the letter case\n\
- `press [key_comb]`: Simulates pressing a key combination on the keyboard (e.g., press [PageDown], press [Enter])\n\
- `go back`: Return this when the current web page does not contain useful information and the user should go back to the previous web page\n\
When mapping the next step into actions in the above formats, follow the following rules:\n\
- Take the user's objective into consideration, so the action must help complete the task\n\
- Do not repeat previous steps\n\
- Only output a single step in the above format but not explanation\n\
Now analyze the following case.\n"

        self.stop_prompt = "You are an autonomous agent helping users to solve web-based tasks. These tasks will be accomplished through series of actions. \
The information you'll have includes:\n\
- The user's task, including a high-level objective and a more detailed illustration\n\
- The current web page's URL and accessibility tree\n\
- Previous steps performed by the user, where each step includes a description of the action and the target web element\n\
You will decide whether the task specified by the high-level objective is completed (which means the **last** step of the detailed instruction is completed and the current webpage completes the task) and respond \"completed\" or \"incomplete\". If the task requires returning a number or a string and the answer can be obtained in the current webpage, reply \"completed, [answer]\" where \"[answer]\" is the number or string. If the task requires finding a webpage and the current webpage satisfies the requirement, reply \"completed, [answer]\"  where \"[answer]\" is the current URL. Now analyze the following case. First provide the reasonings. Then summarize the answer with \"Summary:\", followed by \"completed\" or \"incomplete\", followed by the answer to the question if applicable. Do not include newlines after \"Summary:\".\n\n"

        self.intent_prompt = intent_prompt

        self.nextstep_prompt = """You are an autonomous intelligent agent tasked with solving web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.
Here's the information you'll have:\n\
- The user's objective: This is the task you're trying to complete.\n\
- The current web page's URL: This is the page you're currently navigating.\n\
- The current web page's HTML: Each element is assigned with an unique ID, denoted by the attribute \"node\".\n\
The actions you can perform include:\n\
- mouse_click_action: click\n\
- keyboard_sequence_action: type a sequence of characters\n\
- keyboard_combination_action: press a set of keys together (e.g., hotkey like ctrl+c)\n\
You will generate a step-by-step guide to complete the task based on the given information. At each step, you can perform only one action to one web element. The output should be in the correct format: a single step consisting of a text description, an action, as well as the node and HTML of the target web element to perform the action. Be coherent, concise, and accurate in your response. Do NOT use any special characters (e.g., "*", "#", etc.) in your response. Follow EXACTLY the format of the response below.\n\
Here is one example:\n\
Description: click \"Users\"\n\
Action: mouse_click_action\n\
Node: 93\n\
Target: <a class=\"slds-tree__item-label\" node=\"93\">\n\
Now complete the following task by generating a single next step.\n\
"""
        # Accumulates the summaries of "stop" checks that came back incomplete.
        self.note = ""

    def reflect(self, prompt, goal):
        """Send ``prompt`` to the chat model under the prefix selected by ``goal``.

        Args:
            prompt: Case-specific text appended after the goal's prefix.
            goal: "stop", "reflect", "intent" or "map"; any other value
                selects the next-step prefix.

        Returns:
            The ``(flag, payload)`` pair produced by ``analyze`` for the
            model's reply.
        """
        print("\n" + "-"*15, "CALLING GPT AGENT FOR", goal.upper(), "-"*15)
        prefixes = {
            "stop": self.stop_prompt,
            "reflect": self.reflect_prompt,
            "intent": self.intent_prompt,
            "map": self.map_prompt,
        }
        sys_prompt = prefixes.get(goal, self.nextstep_prompt)

        user_content = sys_prompt + prompt
        if goal == "stop" and self.note != "":
            # Insert the accumulated note right before the decision
            # instructions of the stop prompt.
            split_at = max(sys_prompt.find("You will decide whether the task"), 0)
            user_content = (
                sys_prompt[:split_at]
                + "Here's the note from previous steps: " + self.note + "\n"
                + sys_prompt[split_at:]
                + prompt
            )

        messages = [
            {"role": "system", "content": "You are a helpful assistant designed to solve web-based tasks"},
            {"role": "user", "content": user_content},
        ]

        # Log exactly what is sent, including the injected note if any.
        print("[PROMPT]", user_content)

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
        )

        # The message content can be None; treat that as an empty reply.
        generated_text = response.choices[0].message.content or ""
        print("\n[RESPONSE]", generated_text)
        flag, action = self.analyze(generated_text, goal)
        print("\n[RESULTS]", flag, action)

        return flag, action

    @staticmethod
    def _bracket_span(text, closings):
        """Return the slice from the first "[" through the ``closings``-th "]" after it.

        Returns ``None`` when ``text`` does not contain the required brackets.
        """
        start = text.find("[")
        if start == -1:
            return None
        end = start
        for _ in range(closings):
            end = text.find("]", end)
            if end == -1:
                return None
            end += 1
        return text[start:end]

    def analyze(self, generated_text, goal):
        """Convert a raw completion into the ``(flag, payload)`` pair for ``goal``.

        Args:
            generated_text: Text returned by the chat model.
            goal: Same goal string that was passed to ``reflect``.

        Returns:
            * "stop": ``(can_stop, summary)``; an incomplete summary is also
              appended to ``self.note``.
            * "reflect": ``(candidate_number, [])`` or ``(-1, [])`` when the
              reply contains no number.
            * "intent": ``(True, text)`` with the "Detailed Task Objective: "
              label removed.
            * "nextstep": ``(True, generated_text)``.
            * anything else (action mapping): ``(True, action)`` for a
              recognised action, otherwise ``(False, "scroll [down]")``.
        """
        if goal == "stop":
            generated_text = generated_text.replace("**","")
            # rfind returns -1 (truthy) when the marker is absent, so it must
            # be compared explicitly; otherwise the reply would be cut down to
            # its last character.
            summary_at = generated_text.rfind("Summary:")
            if summary_at != -1:
                generated_text = generated_text[summary_at:].replace("Summary:\n", "Summary: ").replace("Summary: \n", "Summary: ")
            lowered = generated_text.lower()
            can_stop = "completed" in lowered and "incomplete" not in lowered
            generated_text = re.sub(r"\s+", " ", generated_text.replace("Summary:", "").strip())

            if not can_stop:
                print("[Add note]",generated_text)
                self.note += generated_text

            return can_stop, generated_text

        elif goal == "reflect":
            # First (optionally negative) integer in the reply; a lone "-" or
            # a range such as "2-3" no longer breaks the conversion.
            number = re.search(r"-?\d+", generated_text)
            if number is None:
                return -1, []
            return int(number.group()), []

        elif goal == "intent":
            return True, generated_text.replace("Detailed Task Objective: ", "")

        elif goal == "nextstep":
            return True, generated_text

        else:
            lowered = generated_text.lower()
            if "stop" in lowered:
                return True, "stop"
            elif "go back" in lowered:
                return True, "go back"

            # Number of bracket groups each verb carries.
            bracket_groups = {"click": 1, "type": 2, "press": 1}
            for word in lowered.split():
                if word in bracket_groups:
                    span = self._bracket_span(generated_text, bracket_groups[word])
                    if span is None:
                        # Verb without the required brackets: fall back.
                        break
                    return True, word + " " + span

            return False, "scroll [down]"


class WorkflowAgent(Agent):
    """prompt-based agent that emits action given the history"""

    @beartype
    def __init__(
        self,
        output_dir: str
    ) -> None:
        """Load the fine-tuned workflow model and reset the per-task state.

        The Qwen2-7B-Instruct base model, its config and tokenizer are loaded
        from the Hugging Face hub cache; the adapter and the extra weights in
        ``output_dir`` are applied on top, and the model is put in eval mode on
        the GPU.

        Args:
            output_dir: Directory holding the adapter and a ``model.pt`` file.
        """
        super().__init__()

        base_model = "Qwen/Qwen2-7B-Instruct"
        hub_cache = "/home/ubuntu/.cache/huggingface/hub"

        config = transformers.AutoConfig.from_pretrained(
            base_model,
            cache_dir=hub_cache,
            token=HF_TOKEN,
        )
        model = transformers.AutoModelForCausalLM.from_pretrained(
            base_model,
            config=config,
            cache_dir=hub_cache,
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            base_model,
            cache_dir=hub_cache,
            model_max_length=MAX_CTX_LEN,
            padding_side="left",
            token=HF_TOKEN,
        )
        # Pad with the EOS token, on the left.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"

        # Fine-tuned weights: adapter first, then the additional state dict
        # (non-strict, so keys absent from the file are left untouched).
        model.load_adapter(output_dir)
        extra_weights = torch.load(f"{output_dir}/model.pt")
        model.load_state_dict(extra_weights, strict=False)
        model.eval().cuda()

        self.model = model
        self.tokenizer = tokenizer
        # GPT helper used for candidate selection and action mapping.
        self.user = GPTAgent()

        # Step history, as fed back to the workflow model and to the GPT helper.
        self.prev_actions = ""
        self.prev_actions_acc = ""
        self.num_prev_actions = 0
        self.last_action = ""

        # Task / page state.
        self.intent = None
        self.domain = None
        self.prevacc = None

        # Control flags and pending work.
        self.stop = False
        self.goback = False
        self.subsequent_actions = []
        self.expanding = []
        self.consecutive_scroll = 0

        # Text length cap used by truncate_text (-1 disables truncation).
        self.max_text_len = 100

    def set_deterministic(self, seed):
        """Seed the torch (CPU and all CUDA devices), NumPy and ``random`` generators."""
        seeders = (
            torch.manual_seed,
            np.random.seed,
            random.seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)

    def eval(self, doms, seed = 0, max_try = 10):
        """Sample next-step candidates for every DOM chunk and return the chosen one.

        For each chunk in ``doms`` the workflow model is sampled at most
        ``max_try`` times. Every well-formed candidate votes for its target
        node id; sampling for the current chunk stops early once some node has
        collected ``max_try // 2`` votes (votes accumulate across chunks). If
        all votes agree on one node, the longest candidate for that node is
        returned; otherwise the GPT helper chooses among one candidate per
        node.

        Args:
            doms: Sequence of HTML observation strings (chunks of the page).
            seed: Seed passed to ``set_deterministic`` before each chunk.
            max_try: Upper bound on the number of samples drawn per chunk.

        Returns:
            The selected candidate step text, or ``""`` when no usable
            candidate was produced or the helper made no valid choice.
        """
        target_predictions = defaultdict(int)
        responses = defaultdict(list)
        print("\n" + ("-"*15), "CALLING WORKFLOW AGENT FOR STEP", self.num_prev_actions + 1, "-"*15)

        step_prefix = str(self.num_prev_actions + 1) + ".\n"
        # After opening a dropdown, the next step is forced to start with a
        # "click dropdown item" description.
        force_dropdown = self.num_prev_actions >= 1 and (
            "hasPopup: menu expanded: False" in self.last_action or "label=" in self.last_action
        )

        for chunkidx, raw_obs in enumerate(doms):
            self.set_deterministic(seed)

            raw_inp = self.task_meta_info + "Observation: " + raw_obs + "\nStep-by-step guide:\n" + self.prev_actions
            print("[DOM CHUNK",chunkidx+1, "OUT OF", len(doms), "]")
            print(raw_inp)

            messages = [
                {"role": "system", "content": PREPEND},
                {"role": "user", "content": raw_inp},
            ]
            chat_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            encoded = self.tokenizer(chat_text, return_tensors="pt")
            model_inputs = {
                key: value.to(self.model.device).reshape(1, -1)
                for key, value in encoded.items()
            }
            # Length of the prompt proper; the forced prefix below is decoded
            # as part of the generated text.
            input_len = model_inputs["input_ids"].shape[1]

            if force_dropdown:
                # Append the forced prefix once per chunk. Doing this inside
                # the retry loop would make the prompt grow on every attempt.
                forced = self.tokenizer(
                    step_prefix + "Description: click dropdown item",
                    add_special_tokens=False,
                    return_tensors="pt",
                )
                for key in model_inputs:
                    model_inputs[key] = torch.cat([model_inputs[key], forced[key].to(self.model.device)], -1)

            # A bounded loop: malformed generations consume an attempt instead
            # of retrying forever.
            for num_try in range(max_try):
                generated_ids = self.model.generate(**model_inputs, max_new_tokens=200, do_sample=True, top_p=0.95, temperature=0.6, pad_token_id=self.tokenizer.eos_token_id)
                generated_text = self.tokenizer.batch_decode([generated_ids[0][input_len:]], skip_special_tokens=True)[0]

                desc_at = re.search("Description: ", generated_text)
                action_at = re.search("\nAction: ", generated_text)
                if desc_at is None or action_at is None:
                    print("[REMOVED] PATTERN NOT MATCH")
                    continue

                # Collapse whitespace inside the description and inside the
                # target so each field sits on a single line.
                generated_text = (
                    step_prefix
                    + re.sub(r"\s+", " ", generated_text[desc_at.start():action_at.start()])
                    + generated_text[action_at.start():]
                )
                target_at = re.search("Target: ", generated_text)
                if target_at is not None:
                    generated_text = generated_text[:target_at.start()] + re.sub(r"\s+", " ", generated_text[target_at.start():])
                generated_text = generated_text.replace("</s>", "").strip() +"\n"
                print("[CANDIDATE " + str(num_try) + "]")
                print(generated_text)

                match = re.search(r".\nDescription: ([^\n]+)\nAction: ([^\n]+)\nNode: ([^\n]+)\nTarget: (.+)", generated_text)
                if not match:
                    print("[REMOVED] PATTERN NOT MATCH")
                    continue

                # Node id: first token of the "Node:" field, falling back to
                # the first node="..." attribute in the text.
                node_field = re.split(r"[ \n]", match.group(3), maxsplit=1)[0]
                try:
                    tid = int(node_field)
                except ValueError:
                    attr = re.search(r'node="([^"]*)"', generated_text)
                    try:
                        tid = int(attr.group(1))
                    except (AttributeError, ValueError):
                        print("[INVALID GENERATED ID]")
                        continue

                target_predictions[tid] += 1
                responses[tid].append(generated_text)

                if target_predictions[tid] >= max_try // 2:
                    break

        if not target_predictions:
            return ""

        def longest(texts):
            # Stable sort: among equally long texts the latest one wins.
            return sorted(texts, key=len)[-1]

        if len(target_predictions) == 1:
            # dict views are not subscriptable, so take the only key via iteration.
            return longest(responses[next(iter(target_predictions))])

        sorted_target_predictions = sorted(target_predictions.items(), key=lambda item: item[1])
        prompt = self.task_meta_info +"Accessbility tree:\n" + self.acc[:MAX_GPT_LEN] + "\nPrevious steps:\n" + (self.prev_actions if len(self.prev_actions) > 0 else "None\n") + "Proposed next steps:\n"
        idx_id_map = {}

        for candidate_idx, (tid, _) in enumerate(sorted_target_predictions, start=1):
            candidate = longest(responses[tid])
            prompt += "Candidate No. " + str(candidate_idx) + ":\n" + candidate[candidate.index("Description"):] + "\n"
            idx_id_map[candidate_idx] = tid

        selected_cidx, _ = self.user.reflect(prompt, "reflect")
        if selected_cidx not in idx_id_map:
            # Covers the helper's -1 "no answer" as well as any number that
            # does not correspond to a listed candidate.
            return ""

        target_id_pred = idx_id_map[selected_cidx]
        print("[MOST PICK BY GPT]", target_id_pred)
        return longest(responses[target_id_pred])

    def select_nearest_neighbor(self, tree, target, candidates):
        """Pick the candidate node that best stands in for ``target`` in ``tree``.

        Candidates are examined in order. A candidate at tree distance 0 is
        taken immediately. Otherwise the candidate with the highest text
        overlap among those whose child distance is 0 wins (if any has a
        positive overlap); failing that, the candidate with the smallest tree
        distance wins, ties being broken by the highest text overlap.

        Args:
            tree: Parsed document supporting ``xpath``.
            target: Value of the ``node`` attribute of the reference element.
            candidates: Sequence of ``node`` attribute values to choose from.

        Returns:
            ``(candidate, candidate_rep, notsure)`` where ``candidate_rep`` is
            the childless printout of the chosen element and ``notsure`` is
            ``False`` only for an exact (distance 0) hit.
        """

        def lookup(node_id):
            return tree.xpath(f"//*[@node=\"{str(node_id)}\"]")[0]

        target_node = lookup(target)
        target_text = print_without_children(target_node)

        distances = []
        child_scores = []
        overlaps = []
        reps = []
        exact_at = None

        for position, cand in enumerate(candidates):
            cand_node = lookup(cand)
            reps.append(print_without_children(cand_node))
            d, dchild = tree_distance(target_node, cand_node)
            overlap = calculate_overlap_percentage(reps[-1], target_text)
            distances.append(d)

            if d == 0:
                exact_at = position
                break

            # Only candidates with child distance 0 carry a child score.
            child_scores.append(overlap if dchild == 0 else 0)
            overlaps.append(overlap)

        notsure = exact_at is None
        if exact_at is not None:
            cidx = exact_at
            print("[NN FOUND EXACT ELEMENT]", cidx)
        elif len(child_scores) > 0 and np.max(child_scores) > 0:
            cidx = np.argmax(child_scores)
            print("[NN FOUND MOST SIMILAR CHILD]", cidx)
        else:
            cidx = np.argmin(distances)
            tied = np.flatnonzero(np.array(distances) == distances[cidx])
            if tied.size > 1:
                # Several candidates share the minimum distance: keep the one
                # with the largest overlap (first one on equal overlap).
                cidx = int(tied[np.argmax(np.array(overlaps)[tied])])
            print("[NN FOUND SOMETHING ELSE]", cidx)

        return candidates[cidx], reps[cidx], notsure



    def parse(self, response):
        """Map a generated step (``response``) to an executable action string.

        The step text is expected to contain "Action: " and "\\nNode: "
        markers. A prompt built from an HTML snippet, the accessibility tree
        and the previous steps is sent to the GPT helper with goal "map", and
        the returned action is post-processed.

        Returns a ``(action, descriptions)`` pair; the early exits return an
        empty description list.

        NOTE(review): as written this method references several names that are
        neither assigned inside it nor defined in the visible part of the file
        (``acc``, ``actions``, ``cleaned_tree``, ``full_id_map``,
        ``generated_text``, ``obs_nodes_info``), and uses ``action_target_acc``
        before its assignment. The individual spots are flagged below; the
        intended behaviour needs to be confirmed before any of them is changed.
        """
        print("\n", "-"*15, "MAPPING OUTPUT TO ACTION", "-"*15)
        print(response)

        # Inverse lookups (value -> key) of the id maps of the current observation.
        inv_full_id_map = {v: k for k, v in self.full_id_map.items()}
        inv_acc_id_map = {v: k for k, v in self.acc_id_map.items()}
        inv_cleaned_id_map = {v: k for k, v in self.cleaned_id_map.items()}
        full_tree = lxml_html.fromstring(self.full_dom)
        tree = lxml_html.fromstring(self.dom)

        # Locate the fields of the step text. `action` runs from just after
        # "Action: " to the END of "\nNode: ", so it includes that marker;
        # `target` is the text after "Node: " up to the first space.
        deidx = re.search("\nAction: ", response).start()
        asidx = re.search("Action: ", response).end()
        aeidx = re.search("\nNode: ", response).end()
        action = response[asidx:aeidx]
        target = response[aeidx:]
        teidx = re.search(" ", target).start()
        target = target[:teidx]

        # NOTE(review): parsed_response is built here but not read again in
        # this method.
        parsed_response = ""
        if "click" in action:
            parsed_response += "click"
        elif "sequence" in action or "type" in response[:deidx].lower():
            parsed_response += "type"
        else:
            parsed_response += "press"

        
        # Window of node ids around the target used to cut an HTML snippet.
        window_size = 50
        all_nodes = sorted(self.cleaned_id_map.keys(), key=lambda x: int(x)) 
        print(all_nodes,target)
        target_list_idx = all_nodes.index(target)
        target_low = max(0, int(target_list_idx) - window_size)
        target_high = min(len(all_nodes), int(target_list_idx) + window_size)

        # NOTE(review): target_list_idx is the int returned by list.index, so
        # subscripting it raises TypeError; presumably all_nodes[...] was
        # meant. The start offset is also taken from target_high and the end
        # offset from target_low, which looks swapped — confirm.
        dom_start = re.search("node=\"" + str(target_list_idx[target_high]) + "\"", self.dom).start()
        dom_end = re.search("node=\"" + str(target_list_idx[target_low]) + "\"", self.dom).end()
        snippet = self.dom[dom_start:dom_end]

        # NOTE(review): bare `acc` is not defined in this method; self.acc is
        # used elsewhere in the class — confirm.
        prompt = self.task_meta_info + "HTML snippit:\n" + snippet + "\nAccessibility tree:\n" + acc[:MAX_GPT_LEN] + "\nPrevious steps:\n" + (self.prev_actions_acc if len(self.prev_actions_acc) > 0 else "None\n") + "\nProposed next step:\n" + response[re.search("Description", response).start():]                
        _, action = self.user.reflect(prompt, "map")

        if action in ["scroll [down]", "stop"]:
            return action, []

        if action == "go back":
            
            # Record a synthetic "go back" step in the history.
            self.prev_actions += str(self.num_prev_actions + 1) + ".\nDescription: This web page does not contain useful information. Go back to previous page\nAction: mouse_click_action\nNode: 9 9 9 9 9\nTarget: <div node=\"9\">Go Back</div>\n"
            self.num_prev_actions += 1
            self.stop = False
            return "go back", []

  
        print("[SELECTED ACTION]", action)
        # NOTE(review): action_target_acc is only assigned further down, so
        # this read happens before assignment.
        selected_des = self.acc[re.search("\\["+action_target_acc+"\\]", self.acc).end():]
        selected_des_list = []
        # NOTE(review): `actions` is not assigned anywhere in this method.
        for aidx, a in enumerate(actions):

            if a[:4] == "type" or a[:5] == "click":

                # Accessibility-tree id inside the first pair of brackets.
                selected_acc = a[re.search("\\[", a).end():re.search("\\]", a).start()]
                withbrackets = "\\["+selected_acc+"\\]"              
                
                if re.search(withbrackets, acc) is not None:
                    # Text following the id on the same line of the tree.
                    selected_des = acc[re.search(withbrackets, acc).end():]
                    if "\n" in selected_des:
                        selected_des = selected_des[:re.search("\n", selected_des).start()].strip()                        
 
                    alreadyexpanded = False
                    if True:
                        if inv_acc_id_map[selected_acc] in inv_cleaned_id_map.keys():
                            selected_eid = inv_cleaned_id_map[inv_acc_id_map[selected_acc]]
                            try:
                                # NOTE(review): `cleaned_tree` is not defined in
                                # this method; the bare except turns the
                                # resulting error into a scroll fallback.
                                selected_eid_rep = print_without_children(cleaned_tree.xpath(f"//*[@node=\"{selected_eid}\"]")[0])
                                selected_eid_rep=selected_eid_rep[:selected_eid_rep.find(">")+1]
                            except:
                                return "scroll [down]", []
                        else:

                            # Element missing from the cleaned map: look it up
                            # in the parsed DOM and prefer its best-overlapping
                            # child that is present in the cleaned map.
                            selected_eid = inv_full_id_map[inv_acc_id_map[selected_acc]]
                            selected_node = tree.xpath(f"//*[@node=\"{selected_eid}\"]")[0]
                            selected_eid_rep = print_without_children(selected_node)
                            cscores = []
                            ceids = []
                            for child in selected_node.getchildren():
                                # NOTE(review): bare `full_id_map` is not
                                # defined here; self.full_id_map exists.
                                childbe = full_id_map[child.attrib["node"]]
                                if childbe in inv_cleaned_id_map:
                                    childeid = inv_cleaned_id_map[childbe]
                                    childeid_rep = print_without_children(cleaned_tree.xpath(f"//*[@node=\"{childeid}\"]")[0])
                                    ceids.append(childeid)
                                    cscores.append(calculate_overlap_percentage(childeid_rep, selected_eid_rep))
                            if len(cscores) > 0:
                                selected_eid = ceids[np.argmax(cscores)]
                                selected_eid_rep = print_without_children(cleaned_tree.xpath(f"//*[@node=\"{selected_eid}\"]")[0])

                        if a.split()[0] == "type":
                            revise_des = "Type in " + selected_des + ". Enter the content: " + a[re.search("\\] \\[", a).end():-1]
                        else:
                            revise_des = a.split()[0] + " " + selected_des

                        # NOTE(review): `generated_text` is not defined in this
                        # method, so the bare except below always takes the
                        # fallback unless a global of that name exists.
                        try:
                            selected_des_list.append(revise_des + " (Target HTML: " + generated_text[re.search("Target: ", generated_text).end():].replace("\n"," ").strip()+")")
                        except:
                            selected_des_list.append(revise_des)
                        
                                
                        # Rebuild the step text with the resolved element id
                        # (repeated five times in the "Node:" field).
                        action = "keyboard_sequence_action" if a[:4] == "type" else "mouse_click_action"
                        revised_step = "\nAction: " + action +"\nNode: " + (selected_eid + " ")+ (selected_eid + " ")+ (selected_eid + " ")+ (selected_eid + " ")+ (selected_eid )  + "\nTarget: " +  selected_eid_rep 

                        revised_step = str(self.num_prev_actions + 1) + ".\nDescription: " + revise_des + revised_step +"\n"
                        print("[REVISE ACTION " + str(aidx) + "]", revised_step)

                # NOTE(review): revised_step is unbound here when the id was
                # not found in the accessibility tree above.
                self.prev_actions += revised_step.strip() + "\n"
                self.num_prev_actions += 1
            elif (a[:5] == "press") and aidx == 0:
                self.prev_actions += generated_text
                self.num_prev_actions += 1
                selected_des_list.append(generated_text[re.search("Description: ", generated_text).end():re.search("\nAction: ", generated_text).start()])
            
            # Every action after the first is queued for later execution.
            if aidx > 0:
                self.subsequent_actions.append(a)

        if len(selected_des_list) > 0:
            self.prev_actions_acc += "\n".join(selected_des_list) + "\n"
            self.last_action = actions[0]

        # Decide whether the target lies outside the first 720-unit band of
        # the page and, if so, queue the action behind a scroll.
        action_target_acc = re.findall(r'\d+', action)[0]
        # NOTE(review): bare `obs_nodes_info` is not defined here;
        # self.obs_nodes_info is used elsewhere in the class.
        node_info = obs_nodes_info[action_target_acc]
        node_bound = node_info["union_bound"]
        x, y, width, height = node_bound
        center_x = x + width / 2
        center_y = y + height / 2
        newy = int(center_y / 720.0)
        if newy >= 1:
            self.subsequent_actions.append(action)
            if newy == 1:
                action = "scroll [down]"
            else:
                action = "scroll [down"+str(newy)+"]"      
        elif newy < 0 :
            self.subsequent_actions.append(action)
            newy = -newy + 1
            if newy == 1:
                action = "scroll [up]"
            else:
                # NOTE(review): assigns `aaction`, not `action` — looks like a typo.
                aaction = "scroll [up"+str(newy)+"]"

        # NOTE(review): returns actions[0], so the scroll computed into
        # `action` above is not what is returned — confirm.
        return actions[0], selected_des_list


    def set_domain(self):
        """Set ``self.domain`` from the port found in ``self.url`` ("map" if none matches)."""
        # Checked in this order; the first port contained in the URL wins.
        port_domains = (
            (":9999", "reddit"),
            (":7770", "shopping"),
            (":8023", "git"),
            (":7780", "admin"),
        )
        self.domain = next(
            (name for port, name in port_domains if port in self.url),
            "map",
        )


    def replace_urls(self):
        """Rewrite host names in ``dom``, ``full_dom``, ``url`` and ``acc``.

        First the original host is replaced by ``IP_ADDR``; then the
        replacements registered for ``self.domain`` are applied in order.
        """
        fields = ('dom', 'full_dom', 'url', 'acc')

        for name in fields:
            setattr(self, name, getattr(self, name).replace("http://metis.lti.cs.cmu.edu", IP_ADDR))

        # Per-domain substitutions: source text -> replacement.
        rewrites = {
            "reddit": {(IP_ADDR + ":9999"): "https://www.postmill.xyz", "reddit": "postmill", "Reddit": "Postmill"},
            "shopping": {(IP_ADDR + ":7770"): "https://www.onestopmarket.com"},
            "git": {(IP_ADDR + ":8023"): "https://www.gitlab.com"},
            "admin": {(IP_ADDR + ":7780"): "https://www.magento.com", "/admin/admin": "/admin"},
            "map": {(IP_ADDR + ":3000"): "https://www.openstreetmap.org"},
        }

        for name in fields:
            for old, new in rewrites[self.domain].items():
                setattr(self, name, getattr(self, name).replace(old, new))


    def revise_intent(self, intent):
        """Ask the GPT helper for a detailed objective and store the combined intent.

        The (domain-adjusted) objective is kept in ``self.ori_intent``. The
        helper's refinement is appended after " Specifically: " and the result
        is stored in ``self.intent``.

        Args:
            intent: The user's high-level objective.
        """
        if self.domain == "reddit":
            intent = intent.replace(" reddit ", " postmill ").replace("reddit","forum").replace("Reddit", "Postmill")

        self.ori_intent = intent
        prompt = intent + " Here's the current webpage for reference:\n" + self.acc
        _, detailed = self.user.reflect(prompt, "intent")

        if not detailed or not detailed.strip():
            # An empty refinement would make the last-character check below
            # fail; keep the original objective unchanged instead.
            self.intent = self.ori_intent
        else:
            # Close the refinement with a period when it ends on a letter or digit.
            if detailed[-1].isalnum():
                detailed += "."
            self.intent = self.ori_intent + " Specifically: " + detailed
        print("[REVISE INTENT]", self.intent)

    def execute_remaining_actions(self):
        """Execute the next queued follow-up action, scrolling first if needed.

        The head of ``self.subsequent_actions`` is adjusted to the current
        observation (id shift after going back, remapping of ids that are not
        in ``self.obs_nodes_info``). If its target lies outside the first
        720-unit vertical band, a scroll action is inserted in front of it and
        executed instead.

        Returns:
            The action built by ``create_id_based_action`` for the executed
            step, or ``None`` when the queue is empty or the head action was
            dropped because its target could not be resolved.
        """
        if not self.subsequent_actions:
            return None

        parsed_response = self.subsequent_actions[0]

        if self.goback:
            # Shift the id by the difference between the first numbers of the
            # current and the previous accessibility tree.
            ids = re.findall(r'\d+', parsed_response)
            if ids:
                prevnum = ids[0]
                diff = int(re.findall(r'\d+', self.acc)[0]) - int(re.findall(r'\d+', self.prevacc)[0])
                newnum = int(prevnum) + diff
                parsed_response = parsed_response.replace(prevnum, str(newnum))
                self.subsequent_actions[0] = parsed_response

        if "scroll" not in parsed_response and re.search(r'\d+', parsed_response):
            action_target = re.findall(r'\d+', parsed_response)[0]

            if action_target not in self.obs_nodes_info.keys():
                # The id is not an accessibility id of this observation; try
                # to translate it through the cleaned -> acc id maps.
                try:
                    model_selected_acc = self.acc_id_map[self.cleaned_id_map[action_target]]
                    parsed_response = parsed_response.replace(action_target, model_selected_acc)
                    self.subsequent_actions[0] = parsed_response
                    action_target = re.findall(r'\d+', parsed_response)[0]
                except (KeyError, IndexError, TypeError):
                    # Unresolvable target: drop the action.
                    self.subsequent_actions = self.subsequent_actions[1:]
                    return None

            if action_target not in self.obs_nodes_info:
                # The translated id is unknown as well: drop the action.
                self.subsequent_actions = self.subsequent_actions[1:]
                return None

            node_bound = self.obs_nodes_info[action_target]["union_bound"]
            _, y, _, height = node_bound
            center_y = y + height / 2
            # Number of 720-unit bands between the top of the page and the
            # element centre (presumably the viewport height — confirm).
            newy = int(center_y / 720.0)
            if newy >= 1:
                self.subsequent_actions.insert(0, "scroll [down]" if newy == 1 else "scroll [down"+str(newy)+"]")
            elif newy < 0 :
                newy = -newy + 1
                self.subsequent_actions.insert(0, "scroll [up]" if newy == 1 else "scroll [up"+str(newy)+"]")

            # NOTE(review): incremented for every resolved non-scroll action,
            # whether or not a scroll was inserted — confirm this is intended.
            self.consecutive_scroll += 1

        parsed_response = self.subsequent_actions[0]
        self.subsequent_actions = self.subsequent_actions[1:]
        self.last_action = parsed_response
        print("[EXECUTE REMAINING ACTION]", parsed_response)
        return create_id_based_action(parsed_response)

    def truncate_text(self):
        """Shorten long text runs in ``self.dom`` to ``self.max_text_len`` characters.

        The DOM string is split on the ``"> `` delimiter (delimiters are kept).
        In every piece, the text in front of the first ``<`` is cut to the
        limit when it is longer; the piece is then resumed one character
        before that ``<``. Finally all whitespace runs are collapsed to single
        spaces. Nothing happens when ``max_text_len`` is -1.
        """
        if self.max_text_len == -1:
            return

        limit = self.max_text_len
        pieces = []
        for segment in re.split("(\"> )", self.dom):
            tag_at = segment.find("<")
            if tag_at != -1 and tag_at > limit:
                pieces.append(segment[:limit] + segment[tag_at - 1:])
            else:
                pieces.append(segment)

        self.dom = re.sub(r"\s+", " ", "".join(pieces).strip())

    @beartype
    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
    ) -> Action:
        """Choose the next browser action for the most recent observation.

        The method proceeds in this order:

        1. Load the DOM, full DOM, accessibility tree, URL and id maps from
           ``trajectory[-1]["info"]`` and normalise the text fields.
        2. If actions are queued in ``self.subsequent_actions`` (and fewer
           than five consecutive scrolls have happened), try to execute the
           next queued action.
        3. When ``self.num_prev_actions > 1``, ask ``self.user.reflect``
           whether the task should stop.
        4. Split the DOM into chunks that fit the remaining token budget
           (``MAX_CTX_LEN`` minus the prompt prefix), call ``self.eval`` up
           to five times until it returns a non-empty response, and parse it.

        Args:
            trajectory: Interaction history; only the last entry is read.
            intent: Task objective; stored as ``self.intent`` and
                ``self.ori_intent``.
            meta_data: Not used by this method.

        Returns:
            The action built by ``create_id_based_action``. Falls back to
            ``scroll [down]`` when ``self.eval`` never returns a response.
        """
        print("\n\n" + "="*25,"NEW STEP","="*25)
        print("[NUM PREV STEPS]", self.num_prev_actions)
        print("[PREV ACTIONS]", self.prev_actions)

        text_meta = trajectory[-1]["info"]["observation_metadata"]["text"]
        html_info = text_meta["obs_nodes_info_html"]
        self.dom = text_meta["html"]
        self.full_dom = html_info["full_html"]
        self.acc = html_info["acc"]
        self.url = trajectory[-1]["info"]["page"].url
        self.set_domain()

        self.acc = self.clean_string(self.acc, True)
        self.dom = self.clean_string(self.dom)
        self.full_dom = self.clean_string(self.full_dom)
        self.replace_urls()
        self.truncate_text()

        # if self.intent is None:
        #     self.revise_intent(intent)
        self.intent = self.ori_intent = intent

        self.full_id_map = html_info["full_id_map"]
        self.cleaned_id_map = html_info["cleaned_id_map"]
        self.acc_id_map = html_info["acc_id_map"]
        self.obs_nodes_info = text_meta["obs_nodes_info"]

        print("[AXTREE]")
        print(self.acc)

        # Too many scrolls in a row: abandon the queued plan and re-plan.
        if self.consecutive_scroll >= 5:
            self.subsequent_actions = []

        if len(self.subsequent_actions) > 0:
            action = self.execute_remaining_actions()
            if action:
                return action

        self.task_meta_info = "Objective: " + self.intent + "\nURL: " + self.url + "\n"

        if self.num_prev_actions > 1:
            stop_prompt = self.task_meta_info +"Accessibility tree:\n" + self.acc + "\nUser's actions:\n" + self.prev_actions_acc
            should_stop, ans = self.user.reflect(stop_prompt, "stop")

            if should_stop:
                print("[ISSUE STOP]", ans)
                parsed_response = "stop" + "[" + ans + "]"
                return create_id_based_action(parsed_response)

        # Token budget left for the observation after the fixed prompt prefix.
        cur_len = len(self.tokenizer(self.task_meta_info + self.prev_actions)["input_ids"])
        windowsize = MAX_CTX_LEN - cur_len
        obs_len = len(self.tokenizer(self.dom)["input_ids"])

        if obs_len > windowsize:
            doms = []
            # re.split with one capture group yields
            # [text, sep, text, sep, ..., text]; glue each separator onto the
            # text before it and keep the trailing text as the final line.
            pieces = re.split("(</[a-z]+> <[a-z])", self.dom)
            lines = [pieces[i] + pieces[i + 1] for i in range(0, len(pieces) - 1, 2)]
            lines.append(pieces[-1])

            cursor = 0  # index of the next line to place (avoids re-slicing the list)
            num_iter = 0
            prev_remaining = ""
            dom = ""
            while cursor < len(lines) and num_iter < MAX_RECURSION:
                num_iter += 1
                line = lines[cursor]

                # A line ending in "> <x" carries the first two characters of
                # the next element's opening tag; move them to the next line.
                if line[-4:-1] == "> <":
                    line_to_add = line[:-3]
                    remaining = line[-2:]
                else:
                    line_to_add = line
                    remaining = ""

                dom_new = dom + prev_remaining + line_to_add
                new_len = len(self.tokenizer(dom_new)["input_ids"])
                if new_len > windowsize:
                    if dom == "":
                        # This single line alone exceeds the budget: drop it,
                        # but still hand its trailing "<x" to the next line.
                        # (The original assigned a misspelled, unused
                        # variable here, so the prefix was lost.)
                        cursor += 1
                        prev_remaining = remaining
                    else:
                        # Current chunk is full; start a new one and retry
                        # the same line on the next iteration.
                        doms.append(dom)
                        dom = ""
                else:
                    dom = dom_new
                    cursor += 1
                    prev_remaining = remaining

            if len(dom) > 0:
                doms.append(dom)

        else:
            doms = [self.dom]

        print("-"*15,"CHUNK DOM INTO", len(doms), "PIECES", "-"*15)

        # Retry with a different seed until a non-empty response comes back.
        response = ""
        num_try = 0
        while len(response) == 0 and num_try < 5:
            response = self.eval(doms, seed=num_try)
            num_try += 1

        if len(response) == 0:
            return create_id_based_action("scroll [down]")

        parsed_response, parsed_response_full = self.parse(response)

        # NOTE(review): if self.parse returns a string, parsed_response[0] is
        # a single character and this branch can never run -- confirm the
        # intended check against parse().
        if parsed_response[0] == "stop":
            # The original referenced an undefined local ``acc`` here; the
            # cleaned accessibility tree lives on ``self.acc``.
            stop_prompt = "Objective: " + self.ori_intent + "\nDetailed instruction: " + self.intent + "\nURL: " + self.url + "\nAccessibility tree:\n" + self.acc[:MAX_GPT_LEN] + "\nUser's actions:\n" + self.prev_actions_acc
            should_stop, ans = self.user.reflect(stop_prompt, "stop")

            if should_stop:
                print("[ISSUE STOP]", ans)
                parsed_response = "stop" + "[" + ans + "]"
                return create_id_based_action(parsed_response)

        action = create_id_based_action(parsed_response)
        print("-"*15, "FINAL SELECTED ACTION:",parsed_response, "-"*15)

        # Remember the tree this decision was based on (was an undefined
        # ``acc`` local, which raised NameError).
        self.prevacc = self.acc
        self.consecutive_scroll = 0
        return action

    def reset(self, test_config_file: str) -> None:
        """Return the agent to its initial per-task state.

        Clears the recorded action history, the queued follow-up actions and
        the bookkeeping flags, and creates a fresh ``GPTAgent`` as
        ``self.user``. ``test_config_file`` is accepted but not read here.
        """
        # Action history and the queue of already-planned actions.
        self.prev_actions, self.stop = "", False
        self.intent, self.last_action = None, ""
        self.subsequent_actions, self.num_prev_actions = [], 0

        # Fresh reflection helper for the new task.
        self.user = GPTAgent()

        # Remaining per-task bookkeeping.
        self.expanding, self.prev_actions_acc = [], ""
        self.infoseeking, self.prevacc = False, None
        self.goback, self.consecutive_scroll = False, 0
    
    def clean_string(self, target_string, is_axtree=False):
        """Normalise observation text to plain ASCII.

        Processing order:
        1. Decode HTML entities.
        2. Map a fixed set of typographic characters to ASCII equivalents.
        3. Interpret literal backslash escapes (``\\uXXXX``, ``\\n``, ...)
           via the ``unicode_escape`` codec; skipped if decoding fails.
        4. Map typographic characters again (step 3 may have produced new
           ones), decode ``&amp;``/``&lt;``/``&gt;`` once more, and replace
           every remaining non-ASCII run and disallowed control-character
           run with a single space.
        5. For accessibility trees, drop lines that are an empty
           ``StaticText ''`` or a ``LineBreak``; otherwise collapse all
           whitespace runs to single spaces.

        Args:
            target_string: Text to clean.
            is_axtree: True when ``target_string`` is an accessibility tree
                whose line structure must be preserved.

        Returns:
            The cleaned string.
        """
        # Built per call; the mapping is tiny.
        ascii_equivalents = str.maketrans({
            "\u2013": "-",   # en dash
            "\u2022": "-",   # bullet
            "\u2019": "'",   # right single quotation mark
            "\u2039": "<",   # single left-pointing angle quotation mark
            "\u00d7": "*",   # multiplication sign
            "\u00b7": ".",   # middle dot
            "\u201d": "\"",  # right double quotation mark
            "\uff0b": "+",   # fullwidth plus sign
        })

        target_string = html.unescape(target_string)

        # Map real typographic characters BEFORE the escape round-trip: the
        # round-trip re-reads their UTF-8 bytes as Latin-1, so afterwards
        # they no longer match and would be blanked out by the ASCII filter
        # instead of being translated.
        target_string = target_string.translate(ascii_equivalents)
        try:
            target_string = bytes(target_string, "utf-8").decode("unicode_escape")
        except UnicodeError:
            # Malformed escape (e.g. trailing backslash) or unencodable
            # input: keep the text as it is.
            pass
        # Second pass for characters that arrived as "\uXXXX" escapes.
        target_string = target_string.translate(ascii_equivalents)

        target_string = target_string.replace("&amp;","&").replace("&lt;","<").replace("&gt;",">")
        target_string = re.sub(r'[^\x00-\x7F]+',' ', target_string)
        # Only ASCII is left at this point, so this removes control
        # characters other than tab, newline and carriage return. (No
        # separate private-use-area pass is needed: the ASCII filter above
        # has already removed every such code point.)
        target_string = re.sub(u'[^\u0020-\uD7FF\u0009\u000A\u000D\uE000-\uFFFD\U00010000-\U0010FFFF]+', ' ', target_string)

        if is_axtree:
            target_string = re.sub(r"\n([^\n]+)StaticText \'\'\n", "\n", target_string)
            target_string = re.sub(r"\n([^\n]+)LineBreak \'\n\'\n", "\n", target_string)
        else:
            target_string = re.sub(r"\s+", " ", target_string)

        return target_string

##### Helper Funcs #####

def print_without_children(element):
    """Render ``element`` as ``<tag attrs>text</tag>`` without its children.

    Every attribute is emitted as `` name="value"`` in the order given by
    ``element.attrib``. The element's own text is included, stripped, only
    when it is non-empty after stripping.
    """
    attributes = "".join(
        f' {name}="{value}"' for name, value in element.attrib.items()
    )

    own_text = ""
    if element.text and element.text.strip():
        own_text = element.text.strip()

    return f'<{element.tag}{attributes}>{own_text}</{element.tag}>'

def find_lowest_common_ancestor(node1, node2):
    """Return the deepest node that is an ancestor-or-self of both inputs.

    Parents are followed through ``getparent()``. Returns ``None`` when the
    two nodes share no ancestor (or when either argument is ``None``).
    """
    # Collect node1 together with its whole ancestor chain.
    chain_of_first = set()
    cursor = node1
    while cursor is not None:
        chain_of_first.add(cursor)
        cursor = cursor.getparent()

    # Climb from node2 until we land on something from that chain.
    candidate = node2
    while candidate is not None and candidate not in chain_of_first:
        candidate = candidate.getparent()

    return candidate

def tree_distance(node1, node2):
    """Score how far apart two nodes are in the tree.

    Counts the ``getparent()`` hops from each node up to their lowest common
    ancestor and returns ``(hops1 * 100 + hops2, hops1)``: a combined score
    in which hops from ``node1`` weigh 100 times more than hops from
    ``node2``, followed by the hop count for ``node1`` alone.
    """
    lca = find_lowest_common_ancestor(node1, node2)

    def hops_to_lca(start):
        # Number of parent links between ``start`` and the common ancestor.
        hops = 0
        current = start
        while current != lca:
            current = current.getparent()
            hops += 1
        return hops

    hops_from_first = hops_to_lca(node1)
    hops_from_second = hops_to_lca(node2)

    return hops_from_first * 100 + hops_from_second, hops_from_first

def clean_str(text):
    """Replace each punctuation symbol in a fixed set with a single space."""
    punctuation = "*/'\"()[]\\#&.,:?!<>=-_"
    to_space = str.maketrans(dict.fromkeys(punctuation, ' '))
    return text.translate(to_space)

def calculate_overlap_percentage(sentence1, sentence2):
    """Return the fraction of ``sentence2``'s distinct words found in ``sentence1``.

    Both sentences are passed through ``clean_str``, lower-cased and split
    on whitespace into word sets. The result is the number of shared words
    divided by the number of distinct words in ``sentence2``, or ``0`` when
    ``sentence2`` has no words.
    """
    vocab_first = set(clean_str(sentence1).lower().split())
    vocab_second = set(clean_str(sentence2).lower().split())

    # Guard against division by zero.
    if not vocab_second:
        return 0

    return len(vocab_first & vocab_second) / len(vocab_second)


def construct_workflow_agent(args: argparse.Namespace) -> Agent:
    """Build a ``WorkflowAgent`` with ``output_dir`` set to ``args.model_endpoint``."""
    return WorkflowAgent(output_dir=args.model_endpoint)
